#!/usr/bin/env python3
# scripts/apply_composite_moment.py
"""
Apply the 4D composite moment operator to each field sample.

Reads the 'samples' dataset from an input HDF5 file, applies
`composite_moment_4d` to each sample, and writes the results to the
'samples' dataset of a new 'composite' group in the output HDF5 file.
"""
import os
import argparse
import json
import numpy as np
import h5py
from ar_sim.common.fractal_fits import load_D_values
from ar_sim.common.composite_moment import composite_moment_4d


def _build_parser():
    """Return the argument parser for the composite-moment script."""
    parser = argparse.ArgumentParser(
        description="Apply Composite Moment Operator (4D) to field samples"
    )
    parser.add_argument(
        "--samples", type=str, default="results/samples.h5",
        help="Path to input HDF5 file with dataset 'samples'"
    )
    parser.add_argument(
        "--pivot-params", type=str, default="data/pivot_params.json",
        help="Path to JSON file with pivot parameters"
    )
    parser.add_argument(
        "--sigma", type=float, default=1.0,
        help="Kernel width sigma parameter"
    )
    parser.add_argument(
        "--output", type=str, default="results/composite.h5",
        help="Path to output HDF5 file for composite moment samples"
    )
    return parser


def main(argv=None):
    """Apply ``composite_moment_4d`` to every sample and write the results.

    Steps:
      1. Read the full ``samples`` dataset from ``--samples``.
      2. Load context levels ``n_vals`` and ``D_vals`` via ``load_D_values``.
      3. Load the pivot-parameter JSON object and attach ``D_vals`` to it
         under the key ``'D_vals'``.
      4. Apply ``composite_moment_4d`` to each sample (iterating over the
         dataset's first axis) and stack the results.
      5. Write them to ``composite/samples`` in ``--output``, truncating
         any existing file at that path.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: on argument errors, when the input dataset holds no
            samples, or when the pivot JSON is not an object.
    """
    args = _build_parser().parse_args(argv)

    # Load field samples
    with h5py.File(args.samples, 'r') as infile:
        samples = infile['samples'][:]
    # np.stack below cannot stack zero arrays; fail with a readable message
    # instead of an opaque numpy ValueError.
    if len(samples) == 0:
        raise SystemExit(f"error: dataset 'samples' in {args.samples} is empty")

    # Load context levels and D_values (the third return value is unused here)
    n_vals, D_vals, _ = load_D_values()

    # Load pivot parameters and attach D_vals
    with open(args.pivot_params, 'r', encoding='utf-8') as f:
        pivot = json.load(f)
    if not isinstance(pivot, dict):
        raise SystemExit(
            f"error: {args.pivot_params} must contain a JSON object, "
            f"got {type(pivot).__name__}"
        )
    # np.asarray accepts both lists and ndarrays, so this works whichever
    # load_D_values returns; .tolist() yields plain Python values.
    pivot['D_vals'] = np.asarray(D_vals).tolist()

    # Apply composite moment to each sample (iterates over the first axis)
    composite_samples = np.stack([
        composite_moment_4d(field, n_vals, pivot, sigma=args.sigma)
        for field in samples
    ])

    # Ensure output directory exists; exist_ok avoids a check-then-create race
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Write out to HDF5 with a new group
    with h5py.File(args.output, 'w') as outfile:
        grp = outfile.create_group('composite')
        grp.create_dataset('samples', data=composite_samples)

    print(f"Wrote {len(composite_samples)} composite-moment samples to {args.output}")


if __name__ == '__main__':
    main()
